{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Electra\n",
    "参考链接：    \n",
    "https://github.com/google-research/electra   \n",
    "https://github.com/ymcui/Chinese-ELECTRA  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`Electra` 是 `Efficiently Learning an Encoder that Classifies Token Replacements Accurately` 的缩写；\n",
    "\n",
    "是基于`BERT`进行改进的模型，抛弃传统的`MLM（masked language model）`任务，提出了全新的`replaced token detection`任务，使得模型在保持性能的前提下大大降低了模型参数量，提高了模型的运算速度。\n",
    "\n",
    "   - 能够高效地学习如何将收集来的句子进行准确分词，即 `token-replacement`\n",
    "   - 只需要`RoBERTa`和`XLNet`四分之一的计算量，就能在`GLUE`上达到它们的性能"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 预训练任务\n",
    "现存的预训练模型主要分为两大类：语言模型 (`Language Model,LM`)和掩码语言模型 (`Masked Language Model,MLM`)。\n",
    "- `GPT`就是一种`LM`，它从左到右处理输入文本，根据给定的上下文预测下一个单词。\n",
    "- `BERT、RoBERTa`和`ALBERT`属于`MLM`，它们可以预测输入中被掩盖的少量单词。MLM具有双向的优势，它们可以“看到”要预测的`token`两侧的文本。\n",
    "    - 这些模型只预测了输入的很小的子集(被掩盖的`15%`)，从而减少了从每个句子中获得的信息量。\n",
    "----------\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`ELECTRA`使用的是一种新的预训练任务，叫做`replaced token detection (RTD)`。像`MLM`一样训练一个双向模型，也像`LM`一样学习所有输入位置。通过使用不正确的(但有些可信的)伪`token`替换一些输入`token`.      \n",
    "![](../images/ELECTRA任务.gif)\n",
    "\n",
    "- 首先mask一些input tokens，使用一个生成器预测句中被mask掉的token，\n",
    "- 接下来使用预测的token替代句中的[MASK]标记，然后使用一个判别器区分句中的每个token是原始的还是替换后的。  \n",
    "![](../images/ELECTRA预测任务.jpg)\n",
    "-------------------"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 目标函数：\n",
    "$$min_{\\theta_G,\\theta_D}\\sum L_{MLM}(x,\\theta_G)+\\lambda L_{Disc}(x,\\theta_D)$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 加号左边代表MLM的loss，右边代表discriminator的loss。在预训练的时候， generator和discrimiator同时训练。\n",
    "\n",
    "Generator网络其实就是一个小型MLM，discriminator就是论文所说的ELECTRA模型。在预训练完成之后，generator被丢弃，而discriminator被保留用来做下游任务的基础模型。\n",
    "- MLM仅从15%被mask的tokens学习，而replaced token detection要辨别inputs的所有tokens的“真假”，因而可以学习到所有tokens；\n",
    "- MLM任务中[mask]的存在导致了预训练和fine-tuning数据分布不匹配的问题，而这个问题在ELECTRA模型中不存在。尽管MLM做了一些措施来弥补，但是并没有完全解决这个问题。\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型技巧\n",
    "- 权重共享:\n",
    "    - 论文尝试对generator和discriminator做了两种权重共享：token embeddings共享，以及所有权重共享。实验得到，在保持generator和discriminator大小相同的情况下，不共享权重的GLUE score是83.6，共享token embeddings的GLUE score是84.3，共享所有权重的score是84.4。论文分析，这是因为generator对token embedding有着更好的学习能力，因此共享token embeddings后discriminator也能获得更好的token embeddings。\n",
    "- 更小的Generator\n",
    "    - 如果保持generator和discriminator模型大小一样，ELECTRA大约要花费MLM预训练的两倍计算时间，因此论文提出使用小size的generator。generator的大小为discriminator的1/4-1/2时效果最好\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
